Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing non-default modules to pipeline #188

Merged
merged 3 commits into from
Aug 16, 2022

Conversation

pcuenca
Copy link
Member

@pcuenca pcuenca commented Aug 16, 2022

Addresses #183.

Override modules are recognized and replaced in the pipeline. However, no check is performed about mismatched classes yet. This is because the override module is already instantiated (see https://github.com/huggingface/diffusers/blob/main/src/diffusers/configuration_utils.py#L223), and init_dict in https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipeline_utils.py#L151 no longer has a library_name or a class_name - just the instantiated module.

I'm looking at a way to detect a class mismatch so we can fail more gracefully.

Override modules are recognized and replaced in the pipeline. However,
no check is performed about mismatched classes yet. This is because the
override module is already instantiated and we have no library or class
name to compare against.
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Aug 16, 2022

The documentation is not available anymore as the PR was closed or merged.

@pcuenca
Copy link
Member Author

pcuenca commented Aug 16, 2022

I tried the following (as a quick hack before refactoring):

for name, module in passed_class_obj.items():
    # TODO: verify that the module class belongs to one of the supported classes
    library_name, class_name = config_dict[name]
    library = importlib.import_module(library_name)
    loadable_classes = LOADABLE_CLASSES[library_name]
    class_candidates = {c: getattr(library, c) for c in loadable_classes.keys()}
    for class_name, class_candidate in class_candidates.items():
        if isinstance(module, class_candidate):
            init_kwargs[name] = module
    # Remove it even if not found, as it's not appropriate
    init_dict.pop(name)

However, if we pass a scheduler instance to vae, the isinstance check succeeds for class candidate SchedulerMixin, because SchedulerMixin is part of the supported classes.

If we want to really verify this, I think we should create a more fine-grained mapping from module keys to supported classes, instead of checking all the loadable/importable classes in the library.

Copy link
Member

@anton-l anton-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, although I would add at least one test too 😅

@patrickvonplaten
Copy link
Contributor

Just added a test, all other tests pass

@patrickvonplaten
Copy link
Contributor

@anton-l - feel free to merge and then maybe also add it manually to the release notes quickly :-)

@patrickvonplaten
Copy link
Contributor

Merging!

@patrickvonplaten patrickvonplaten merged commit 513f1fb into main Aug 16, 2022
@patrickvonplaten patrickvonplaten deleted the pipeline-non-default-modules branch August 16, 2022 15:25
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Allow passing non-default modules to pipeline.

Override modules are recognized and replaced in the pipeline. However,
no check is performed about mismatched classes yet. This is because the
override module is already instantiated and we have no library or class
name to compare against.

* up

* add test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants